
# RePaFormer: Reparameterizable Vision Transformer

This repository contains PyTorch evaluation code, training code and pretrained models for __RePaFormer__.

## Setup

First, clone the repository locally:
```
# The repository owner is anonymized for submission.
git clone https://github.com/**************/RePaFormer.git
cd RePaFormer
```
Then, create the environment with Anaconda:
```
conda create -n repaformer python=3.10.14 -y
conda activate repaformer
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
pip install timm==1.0.3 einops ptflops wandb
```
After these installations, the repository is ready to run.

We also use [wandb](https://wandb.ai/site) for real-time tracking and visualization of the training process. Using wandb is optional, but you need to register and log in to wandb before using this functionality.

## Dataset preparation

Download and extract ImageNet train and val images from http://image-net.org/.
The expected directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), with the training and validation data in the `train/` and `val/` folders respectively:
```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
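
Before launching a long run, it can help to confirm that the dataset folder matches this layout. The following is a minimal sketch using only the standard library; `check_imagefolder_layout` is a hypothetical helper and is not part of this repository. It assumes that `train/` and `val/` should contain the same class folders, since `ImageFolder` derives class indices from the sorted folder names.

```python
from pathlib import Path


def check_imagefolder_layout(root):
    """Return {split: number of class folders} for an ImageFolder-style dataset.

    Raises if a split folder is missing, has no class folders, or if the
    class folders of train/ and val/ differ.
    """
    root = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing folder: {split_dir}")
        # Each immediate subdirectory is treated as one class.
        names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
        if not names:
            raise ValueError(f"no class folders found in {split_dir}")
        classes[split] = names
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return {split: len(names) for split, names in classes.items()}
```

For example, `check_imagefolder_layout("/path/to/imagenet")` returns the number of class folders found in each split, or raises an error that describes the problem.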

## Training
### 1. Ordinary training on a single node
To train RePaFormers on ImageNet on a single node with 8 GPUs for 300 epochs, run:

RePa-DeiT-Base
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --output_dir=output/repadeit_base --model RePaViT_Base --feature_norm=BatchNorm --channel_idle
```

RePa-Swin-Base
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --output_dir=output/repaswin_base --model RePaSwin_Base --feature_norm=BatchNorm --channel_idle
```
Please note that the `--channel_idle` argument must be used together with `--feature_norm=BatchNorm`. More models can be found in the `repaxxxx.py` files.
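
The flag constraint above can be checked before a job is submitted. The sketch below is a hypothetical pre-flight check, not the argument parser in `main.py`: it reuses the flag names from the commands in this README, and everything else (function names, the `None` defaults) is an assumption made for illustration.

```python
import argparse


def build_flag_parser():
    """Parser covering only the flags relevant to the constraint."""
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    # Defaults are placeholders; the real defaults live in main.py.
    parser.add_argument("--feature_norm", default=None)
    parser.add_argument("--channel_idle", action="store_true")
    parser.add_argument("--idle_ratio", type=float, default=None)
    return parser


def check_flags(argv):
    """Raise ValueError if --channel_idle is given without BatchNorm."""
    # parse_known_args ignores the other training flags (--model, --data-path, ...).
    args, _unknown = build_flag_parser().parse_known_args(argv)
    if args.channel_idle and args.feature_norm != "BatchNorm":
        raise ValueError("--channel_idle requires --feature_norm=BatchNorm")
    return args
```

Calling `check_flags(["--feature_norm=BatchNorm", "--channel_idle"])` passes, while `check_flags(["--channel_idle"])` raises a `ValueError`.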

### 2. Track your training with wandb
To train with wandb visualization:
```
WANDB_MODE=online python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --output_dir=output/repadeit_base --model RePaViT_Base_patch16_224_layer12 --feature_norm=BatchNorm --channel_idle --use_wandb
```
Please note that the environment variable `WANDB_MODE` must be set when using `--use_wandb`. Choose `WANDB_MODE=online` for real-time tracking on the wandb dashboard, or `WANDB_MODE=offline` to track locally and synchronize later.
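
As an illustration of this requirement, the sketch below fails early when `WANDB_MODE` is missing. `resolve_wandb_mode` is a hypothetical helper, not a function from this repository or from wandb. It accepts only the two modes named in this README; wandb itself recognizes further values, such as `disabled`.

```python
import os


def resolve_wandb_mode(use_wandb, env=None):
    """Return the wandb mode to use, or None when wandb logging is off."""
    env = os.environ if env is None else env
    if not use_wandb:
        return None
    mode = env.get("WANDB_MODE")
    # Restricting to these two values follows this README, not wandb's full list.
    if mode not in ("online", "offline"):
        raise RuntimeError(
            "set WANDB_MODE=online or WANDB_MODE=offline when using --use_wandb"
        )
    return mode
```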

### 3. Training with different idle ratios
To train with a different idle ratio, set the `--idle_ratio` argument. For example, to set the idle ratio to 0.5:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --output_dir=output/repaswin_base --model RePaSwin_Base --feature_norm=BatchNorm --channel_idle --idle_ratio=0.5
```

## Evaluation
### 1. Performance evaluation
To evaluate a trained checkpoint on the ImageNet validation set (the `val/` folder), for instance:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --model RePaSwin_Base --batch-size=128 --feature_norm=BatchNorm --channel_idle --reparam --eval --resume /path/to/checkpoint
```
Here, `--reparam` controls whether the model is reparameterized before evaluation. Without `--reparam`, the vanilla (non-reparameterized) backbone is run.
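
For background on what reparameterization can look like, the sketch below folds an inference-mode BatchNorm into the linear layer that precedes it, so that two operations become one. This is a generic identity written in plain Python, not the code in this repository; all function names are hypothetical, and how RePaFormer actually merges its layers is defined in the model files.

```python
import math


def linear(weight, bias, x):
    """Compute y = W x + b for one input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm: a fixed per-channel affine transform."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


def fold_batchnorm_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (W', b') such that W' x + b' == BN(W x + b)."""
    folded_weight, folded_bias = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)  # per-output-channel scale
        folded_weight.append([scale * w for w in row])
        folded_bias.append(scale * (b - m) + bt)
    return folded_weight, folded_bias
```

Because the running mean and variance are fixed at inference time, the folded layer produces the same outputs as the original pair up to floating-point rounding. A normalization whose statistics depend on each input, such as LayerNorm, cannot be folded this way.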

### 2. Speed evaluation
To test the throughput, add `--test_speed` and `--only_test_speed`:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 --use_env main.py --data-path /path/to/imagenet --model RePaSwin_Base --batch-size=128 --feature_norm=BatchNorm --channel_idle --reparam --test_speed --only_test_speed
```
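
Throughput is usually reported as images processed per second. The sketch below shows one common way to measure it, with warm-up iterations excluded from the timing. It is an assumption about the general approach, not the implementation behind `--test_speed`, and `measure_throughput` is a hypothetical name.

```python
import time


def measure_throughput(step, batch_size, warmup=5, iters=20):
    """Return images per second for a callable that processes one batch."""
    # Warm-up runs are excluded so that one-time setup costs do not skew timing.
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    # On a GPU, `step` should wait for the device to finish (for example with
    # torch.cuda.synchronize()), otherwise only the kernel launch time is measured.
    elapsed = time.perf_counter() - start
    if elapsed <= 0:
        raise RuntimeError("timed region too short to measure")
    return batch_size * iters / elapsed
```

Here `step` would wrap a single forward pass of the model on a batch of `batch_size` images.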

## License
This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
